#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 19 13:28:34 2017

@author: Kevin Liang

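Smoke tests for the Faster R-CNN feature extractors (convnet, ResNet) and the
cluttered-MNIST .tfrecords input pipeline.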
"""

import sys
sys.path.append('../')

import tensorflow as tf
import numpy as np

from Lib.TensorBase.tensorbase.base import Data
from Networks.convnet import convnet
from Networks.resnet import resnet


class faster_rcnn_tests(object):
    """Smoke tests for the Faster R-CNN networks and data pipeline.

    Every test shares one float32 placeholder ``self.x`` of shape
    ``[None, 128, 128, 3]`` and one ``tf.InteractiveSession`` (``self.sess``).
    """

    # Records file read by print_test_image() when no path is given.
    # NOTE(review): this is a machine-specific absolute path, kept as the
    # default only for backward compatibility; pass ``tfrecord_file`` instead.
    DEFAULT_TFRECORD_FILE = ('/home/dcs41/Documents/tf-Faster-RCNN/Data/'
                             'data_clutter/clutter_mnist_valid.tfrecords')

    def __init__(self):
        self.x = tf.placeholder(tf.float32, [None, 128, 128, 3], name='x')
        self.sess = tf.InteractiveSession()

    def test_all(self):
        """Run every network output-dimension test."""
        self.test_convnet_dims()
        self.test_resnet_dims()

    def _check_output_dims(self, featureMaps, expected_shape):
        """Feed one random image through ``featureMaps`` and check its shape.

        Initializes all global variables first, then evaluates ``featureMaps``
        on a single random 128x128x3 image fed through ``self.x``.

        Args:
            featureMaps: tensor (output of a network's ``get_output()``).
            expected_shape: sequence of ints the evaluated output must match.

        Raises:
            AssertionError: if the evaluated output's shape differs.
        """
        self.sess.run(tf.global_variables_initializer())

        # Random pixel values in [0, 255], cast to match the float32 placeholder.
        test_image = np.random.randint(0, 256, [1, 128, 128, 3]).astype(np.float32)
        feat_val = np.asarray(self.sess.run(featureMaps, feed_dict={self.x: test_image}))

        print(feat_val.shape)
        expected = tuple(expected_shape)
        # Compare shapes as tuples. An elementwise ndarray comparison errors
        # out (rather than failing cleanly) when the ranks differ.
        assert feat_val.shape == expected, \
            "expected output shape %s, got %s" % (expected, feat_val.shape)

    def test_convnet_dims(self):
        """Check the convnet built below yields a [1, 32, 32, 128] output."""
        filter_sizes = (3, 3, 3, 3)
        output_channels = (32, 64, 64, 128)
        strides = (1, 2, 1, 2)
        cnn = convnet(self.x, filter_sizes, output_channels, strides)
        self._check_output_dims(cnn.get_output(), (1, 32, 32, 128))

    def test_resnet_dims(self):
        """Check ``resnet(50, ...)`` (presumably ResNet-50) yields [1, 4, 4, 2048]."""
        cnn = resnet(50, self.x)
        self._check_output_dims(cnn.get_output(), (1, 4, 4, 2048))

    def print_test_image(self, tfrecord_file=DEFAULT_TFRECORD_FILE, batch_size=32):
        """Read one batch from a .tfrecords file and plot its first image and box.

        Args:
            tfrecord_file: path of the records file to read; defaults to
                ``DEFAULT_TFRECORD_FILE``.
            batch_size: number of examples requested from ``Data.batch_inputs``.
        """
        im_dims, gt_boxes, image = Data.batch_inputs(
            self.read_and_decode, tfrecord_file, batch_size=batch_size)
        self.sess.run(tf.local_variables_initializer())
        self.sess.run(tf.global_variables_initializer())
        threads, coord = Data.init_threads(self.sess)
        try:
            _, gt_boxes_val, image_val = self.sess.run([im_dims, gt_boxes, image])
            self.plot_img(image_val[0], gt_boxes_val[0])
        finally:
            # Shut down the threads started by Data.init_threads even when
            # fetching or plotting raises.
            Data.exit_threads(threads, coord)

    @staticmethod
    def plot_img(image, gt_box):
        """Print summary statistics of ``image`` and display it with ``gt_box``.

        Args:
            image: image array; squeezed and shown in grayscale.
            gt_box: 5-element box; indices 0-3 are coordinates and index 4 is
                printed as the digit label.
        """
        # First print out image metrics
        print("Using First Image of the Batch..")
        print("Image Max Value (should be less than 1): %f" % image.max())
        print("Image Min Value (should be greater than -1): %f" % image.min())
        print("Image Mean Value (should be equal to 0): %f" % image.mean())
        print("Digit: %d" % gt_box[4])

        # Import matplotlib lazily so the network tests don't require it.
        import matplotlib
        matplotlib.use('TkAgg')  # For Mac OS
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches
        fig, ax = plt.subplots(1)

        # Plot Image First
        ax.imshow(np.squeeze(image), cmap="gray")

        # NOTE(review): the box layout is inconsistent here. The anchor
        # (gt_box[1], gt_box[0]) treats index 1 as x, but the extents passed to
        # Rectangle(xy, width, height) treat indices 0/2 as the horizontal
        # span. One of the two is wrong. Confirm the layout written by the
        # record generator ([x1, y1, x2, y2, label] vs [y1, x1, y2, x2, label])
        # and fix either the anchor or the extents accordingly.
        box_w = gt_box[2] - gt_box[0]
        box_h = gt_box[3] - gt_box[1]
        rect = patches.Rectangle((gt_box[1], gt_box[0]), box_w, box_h,
                                 linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # Display Final composite image
        plt.show()

    @staticmethod
    def read_and_decode(example_serialized):
        """Parse one serialized example from a cluttered-MNIST .tfrecords file.

        The file is the one generated by MNIST.py.

        Args:
            example_serialized: scalar string tensor holding one tf.Example.

        Returns:
            Tuple ``(dims, gt_boxes, image)``:
            - ``dims``: int64 tensor of 2 values.
            - ``gt_boxes``: int64 tensor of 5 values.
            - ``image``: float32 tensor reshaped to [128, 128].
            Missing ``dims`` and ``gt_boxes`` fields default to all -1.
        """
        features = tf.parse_single_example(
            example_serialized,
            features={
                'image': tf.FixedLenFeature([], tf.string),
                # 5 values: box coordinates plus the digit label (index 4, as
                # printed by plot_img); -1 marks a missing field.
                'gt_boxes': tf.FixedLenFeature([5], tf.int64, default_value=[-1]*5),
                'dims': tf.FixedLenFeature([2], tf.int64, default_value=[-1]*2)
            })
        gt_boxes = features['gt_boxes']
        dims = features['dims']
        # Raw bytes are reinterpreted as float32 and laid out as a 128x128 image.
        image = tf.decode_raw(features['image'], tf.float32)
        image = tf.reshape(image, [128, 128])
        return dims, gt_boxes, image
        
def main():
    """Script entry point: build the test harness and plot one sample record."""
    print("Initiating Tests")
    faster_rcnn_tests().print_test_image()


if __name__ == "__main__":
    # Only run the plotting demo when executed directly, not on import.
    main()